{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UiNxsd4_q9wq"
   },
   "source": [
    "### What-If Tool Image Smile Detection\n",
    "\n",
    "In this demo we demonstrate the use of what-if-tool for image recognition models. Our task is to predict if a person is smiling or not. We provide a CNN that is trained on a subset of [CelebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) and visualize the results on a separate test subset.\n",
    "\n",
    "Copyright 2019 Google LLC.\n",
    "SPDX-License-Identifier: Apache-2.0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensure the right version of Tensorflow is installed.\n",
    "!pip freeze | grep tensorflow==2.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### **Re-run** the below cell until you see the output `Successfully installed h5py-2.10.0`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install version 2.10 of h5py\n",
    "import sys\n",
    "!{sys.executable} -m pip uninstall -y h5py\n",
    "!{sys.executable} -m pip install 'h5py < 3.0.0'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### **Note**: Please ignore any **incompatibility ERROR** that may appear for the packages visions as it will not affect the lab's functionality."
   ]
  },
  {
   "source": [
    "### In order to use the correct __h5py__ version, you will need to restart the notebook's kernel. To do this, select __Kernel__ > __Restart Kernel__ from the top menu."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Download the pretrained keras model files and subset of celeba images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "id": "EBOHfrOP7Iy5"
   },
   "outputs": [],
   "source": [
    "!curl -L https://storage.googleapis.com/what-if-tool-resources/smile-demo/smile-colab-model.hdf5 -o ./smile-model.hdf5\n",
    "!curl -L https://storage.googleapis.com/what-if-tool-resources/smile-demo/test_subset.zip -o ./test_subset.zip\n",
    "!unzip -qq -o test_subset.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define helper functions for dataset conversion from csv to tf.Examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "id": "4H2nX-2dEgsR"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import os\n",
    "from PIL import Image\n",
    "from io import BytesIO\n",
    "\n",
    "# Converts a dataframe into a list of tf.Example protos.\n",
    "# If images_path is specified, it assumes that the dataframe has a special \n",
    "# column \"image_id\" and the path \"images_path/image_id\" points to an image file.\n",
    "# Given this structure, this function loads and processes the images as png byte_lists\n",
    "# into tf.Examples so that they can be shown in WIT. Note that 'image/encoded'\n",
    "# is a reserved field in WIT for encoded image features.\n",
    "def df_to_examples(df, columns=None, images_path=''):\n",
    "  examples = []\n",
    "  if columns == None:\n",
    "    columns = df.columns.values.tolist()\n",
    "  for index, row in df.iterrows():\n",
    "    example = tf.train.Example()\n",
    "    for col in columns:\n",
    "      if df[col].dtype is np.dtype(np.int64):\n",
    "        example.features.feature[col].int64_list.value.append(int(row[col]))\n",
    "      elif df[col].dtype is np.dtype(np.float64):\n",
    "        example.features.feature[col].float_list.value.append(row[col])\n",
    "      elif row[col] == row[col]:\n",
    "        example.features.feature[col].bytes_list.value.append(row[col].encode('utf-8'))\n",
    "    if images_path:\n",
    "      fname = row['image_id']\n",
    "      with open(os.path.join(images_path, fname), 'rb') as f:\n",
    "        im = Image.open(f)\n",
    "        buf = BytesIO()\n",
    "        im.save(buf, format= 'PNG')\n",
    "        im_bytes = buf.getvalue()\n",
    "        example.features.feature['image/encoded'].bytes_list.value.append(im_bytes)\n",
    "    examples.append(example)\n",
    "  return examples\n",
    "\n",
    "# Converts a dataframe column into a column of 0's and 1's based on the provided test.\n",
    "# Used to force label columns to be numeric for binary classification using a TF estimator.\n",
    "def make_label_column_numeric(df, label_column, test):\n",
    "  df[label_column] = np.where(test(df[label_column]), 1, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load the csv file into pandas dataframe and process it for WIT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "id": "MGpLKJI_HY9m"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "data = pd.read_csv('celeba/data_test_subset.csv')\n",
    "examples = df_to_examples(data, images_path='celeba/img_test_subset_resized/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load the keras models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "id": "zZR3i6UZlZ96"
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.models import load_model\n",
    "\n",
    "model1 = load_model('smile-model.hdf5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define the custom predict function for WIT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "id": "E5fYynA9ZpPJ"
   },
   "outputs": [],
   "source": [
    "# This function extracts 'image/encoded' field, which is a reserved key for the \n",
    "# feature that contains encoded image byte list. We read this feature into \n",
    "# BytesIO and decode it back to an image using PIL.\n",
    "# The model expects an array of images that are floats in range 0.0 to 1.0 and \n",
    "# outputs a numpy array of (n_samples, n_labels)\n",
    "def custom_predict(examples_to_infer):\n",
    "  def load_byte_img(im_bytes):\n",
    "    buf = BytesIO(im_bytes)\n",
    "    return np.array(Image.open(buf), dtype=np.float64) / 255.\n",
    "\n",
    "  ims = [load_byte_img(ex.features.feature['image/encoded'].bytes_list.value[0]) \n",
    "         for ex in examples_to_infer]\n",
    "  preds = model1.predict(np.array(ims))\n",
    "  return preds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ldvP-msGPnIv"
   },
   "source": [
    "## Note that this particular model only uses images as input. Therefore, partial dependence plots are flat for all features. These features are provided for slicing and analysis purposes.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Invoke What-If Tool for the data and model {display-mode: \"form\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "UwiWGrLlSWGh"
   },
   "outputs": [],
   "source": [
    "from witwidget.notebook.visualization import WitWidget, WitConfigBuilder, display\n",
    "\n",
    "num_datapoints = 250 \n",
    "tool_height_in_px = 700\n",
    "\n",
    "# Decode an image from tf.example bytestring\n",
    "def decode_image(ex):\n",
    "  im_bytes = ex.features.feature['image/encoded'].bytes_list.value[0]\n",
    "  im = Image.open(BytesIO(im_bytes))\n",
    "  return im\n",
    "\n",
    "# Define the custom distance function that compares the average color of images\n",
    "def image_mean_distance(ex, exs, params):\n",
    "  selected_im = decode_image(ex)\n",
    "  mean_color = np.mean(selected_im, axis=(0,1))\n",
    "  image_distances = [np.linalg.norm(mean_color - np.mean(decode_image(e), axis=(0,1))) for e in exs]\n",
    "  return image_distances\n",
    "\n",
    "# Setup the tool with the test examples and the trained classifier\n",
    "config_builder = WitConfigBuilder(examples[:num_datapoints]).set_custom_predict_fn(\n",
    "    custom_predict).set_custom_distance_fn(image_mean_distance)\n",
    "\n",
    "wv = WitWidget(config_builder, height=tool_height_in_px)\n",
    "display(wv)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "A1s1_SiOyS0l"
   },
   "source": [
    "#### Exploration ideas\n",
    "\n",
    "- In the \"Performance\" tab, set the ground truth feature to \"Smiling\". You can set a scatter axis or binning option to be inference correct and analyze how it varies across other features (i.e. you can make a scatter plot of Young vs inference correct).\n",
    "- Choose an image and click on \"Show nearest counterfactual datapoint\", this will find another example that is closest to the selected image in terms of average color value, but has a different prediction (if selected image is predicted to be \"smiling\" the counterfactual one will have \"not smiling\" prediction).\n",
    "- Define your own custom distance function and set it by calling set_custom_distance_fn on config_builder and explore the counterfactuals. You can even load another neural network to compute distances!\n",
    "- You can slice by any one of the features and analyze the confusion matrix and accuracy for each group.\n",
    "- In the \"Datapoint Editor\" tab, you can upload your own image or download and modify one of the images to see how it affects the inference score.\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "WIT_Smile_Detector.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
